{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyCell(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MyCell, self).__init__()\n",
    "        self.linear = torch.nn.Linear(4, 4)\n",
    "\n",
    "    def forward(self, x, h):\n",
    "        new_h = torch.tanh(self.linear(x) + h)\n",
    "        return new_h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_cell = MyCell()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, h = torch.rand(4, 4), torch.rand(4, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "traced_cell = torch.jit.trace(my_cell, (x, h))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TracedModule[MyCell](\n",
       "  original_name=MyCell\n",
       "  (linear): TracedModule[Linear](original_name=Linear)\n",
       ")"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "traced_cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.4238, -0.0524,  0.5719,  0.4747],\n",
       "        [-0.0059, -0.3625,  0.2658,  0.7130],\n",
       "        [ 0.4532,  0.6390,  0.6385,  0.6584]],\n",
       "       grad_fn=<DifferentiableGraphBackward>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "traced_cell(x, h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "graph(%self : ClassType<MyCell>,\n",
       "      %input : Float(3, 4),\n",
       "      %h : Float(3, 4)):\n",
       "  %1 : ClassType<Linear> = prim::GetAttr[name=\"linear\"](%self)\n",
       "  %weight : Tensor = prim::GetAttr[name=\"weight\"](%1)\n",
       "  %bias : Tensor = prim::GetAttr[name=\"bias\"](%1)\n",
       "  %6 : Float(4, 4) = aten::t(%weight), scope: MyCell/Linear[linear] # /home/jibin/.local/lib/python3.6/site-packages/torch/nn/functional.py:1370:0\n",
       "  %7 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear] # /home/jibin/.local/lib/python3.6/site-packages/torch/nn/functional.py:1370:0\n",
       "  %8 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear] # /home/jibin/.local/lib/python3.6/site-packages/torch/nn/functional.py:1370:0\n",
       "  %9 : Float(3, 4) = aten::addmm(%bias, %input, %6, %7, %8), scope: MyCell/Linear[linear] # /home/jibin/.local/lib/python3.6/site-packages/torch/nn/functional.py:1370:0\n",
       "  %10 : int = prim::Constant[value=1](), scope: MyCell # <ipython-input-2-c6e2cd8665ee>:7:0\n",
       "  %11 : Float(3, 4) = aten::add(%9, %h, %10), scope: MyCell # <ipython-input-2-c6e2cd8665ee>:7:0\n",
       "  %12 : Float(3, 4) = aten::tanh(%11), scope: MyCell # <ipython-input-2-c6e2cd8665ee>:7:0\n",
       "  return (%12)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "traced_cell.graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'import __torch__\\nimport __torch__.torch.nn.modules.linear\\ndef forward(self,\\n    input: Tensor,\\n    h: Tensor) -> Tensor:\\n  _0 = self.linear\\n  weight = _0.weight\\n  bias = _0.bias\\n  _1 = torch.addmm(bias, input, torch.t(weight), beta=1, alpha=1)\\n  return torch.tanh(torch.add(_1, h, alpha=1))\\n'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "traced_cell.code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDecisionGate(torch.nn.Module):\n",
    "    def forward(self, x):\n",
    "        if x.sum() > 0:\n",
    "            return x\n",
    "        else:\n",
    "            return -x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyCell(torch.nn.Module):\n",
    "    def __init__(self, dg):\n",
    "        super(MyCell, self).__init__()\n",
    "        self.dg = dg\n",
    "        self.linear = torch.nn.Linear(4, 4)\n",
    "\n",
    "    def forward(self, x, h):\n",
    "        new_h = torch.tanh(self.dg(self.linear(x)) + h)\n",
    "        return new_h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jibin/.local/lib/python3.6/site-packages/ipykernel_launcher.py:3: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  This is separate from the ipykernel package so we can avoid doing imports until\n"
     ]
    }
   ],
   "source": [
    "my_cell = MyCell(MyDecisionGate())\n",
    "traced_cell = torch.jit.trace(my_cell, (x, h))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "import __torch__.___torch_mangle_0\n",
      "import __torch__\n",
      "import __torch__.torch.nn.modules.linear.___torch_mangle_1\n",
      "def forward(self,\n",
      "    input: Tensor,\n",
      "    h: Tensor) -> Tensor:\n",
      "  _0 = self.linear\n",
      "  weight = _0.weight\n",
      "  bias = _0.bias\n",
      "  x = torch.addmm(bias, input, torch.t(weight), beta=1, alpha=1)\n",
      "  _1 = torch.tanh(torch.add(torch.neg(x), h, alpha=1))\n",
      "  return _1\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(traced_cell.code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "scripted_gate = torch.jit.script(MyDecisionGate())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_cell = MyCell(scripted_gate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "import __torch__.___torch_mangle_3\n",
      "import __torch__.___torch_mangle_2\n",
      "import __torch__.torch.nn.modules.linear.___torch_mangle_4\n",
      "def forward(self,\n",
      "    x: Tensor,\n",
      "    h: Tensor) -> Tensor:\n",
      "  _0 = self.linear\n",
      "  _1 = _0.weight\n",
      "  _2 = _0.bias\n",
      "  if torch.eq(torch.dim(x), 2):\n",
      "    _3 = torch.__isnot__(_2, None)\n",
      "  else:\n",
      "    _3 = False\n",
      "  if _3:\n",
      "    bias = ops.prim.unchecked_unwrap_optional(_2)\n",
      "    ret = torch.addmm(bias, x, torch.t(_1), beta=1, alpha=1)\n",
      "  else:\n",
      "    output = torch.matmul(x, torch.t(_1))\n",
      "    if torch.__isnot__(_2, None):\n",
      "      bias0 = ops.prim.unchecked_unwrap_optional(_2)\n",
      "      output0 = torch.add_(output, bias0, alpha=1)\n",
      "    else:\n",
      "      output0 = output\n",
      "    ret = output0\n",
      "  _4 = torch.gt(torch.sum(ret, dtype=None), 0)\n",
      "  if bool(_4):\n",
      "    _5 = ret\n",
      "  else:\n",
      "    _5 = torch.neg(ret)\n",
      "  return torch.tanh(torch.add(_5, h, alpha=1))\n",
      "\n"
     ]
    }
   ],
   "source": [
    "traced_cell = torch.jit.script(my_cell)\n",
    "print(traced_cell.code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
